{
 "metadata": {
  "name": "mnist"
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "code",
     "collapsed": true,
     "input": [
      "\"\"\"Initialize everything.\"\"\"\n",
      "import cPickle as pickle\n",
      "from decaf import base\n",
      "from decaf.layers import core_layers, regularization, fillers\n",
      "from decaf.layers.data import mnist\n",
      "from decaf.opt import core_solvers\n",
      "from decaf.util import visualize\n",
      "import logging\n",
      "from matplotlib import pyplot\n",
      "import numpy as np\n",
      "\n",
      "# Set the root logger to INFO so the logging.info calls in this notebook show up.\n",
      "logging.getLogger().setLevel(logging.INFO)\n",
      "logging.info('Initialization done.')\n",
      "\n",
      "# Seed numpy's global random number generator.\n",
      "np.random.seed(1701)\n",
      "\n",
      "# Folder handed to the MNIST data layers as their rootfolder.\n",
      "ROOT = '/u/vis/x1/common/mnist'\n",
      "# Value handed to the minibatch layer.\n",
      "MINIBATCH = 128\n",
      "# NOTE(review): METHOD is not read by any of the visible cells -\n",
      "# confirm whether it is still needed.\n",
      "METHOD = 'sgd'\n",
      "# Value handed to the solver as max_iter.\n",
      "MAX_ITER = 2000"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Define LeNet: the magic to recognize images.\n",
      "def lenet():\n",
      "    \"\"\"Build the LeNet layer stack used by this notebook.\n",
      "\n",
      "    Returns:\n",
      "        A list of newly constructed layers in this order: conv1, pool1,\n",
      "        conv2, pool2, flatten, ip1, relu1, ip2.\n",
      "    \"\"\"\n",
      "    # Two convolution + max-pooling stages.\n",
      "    conv1 = core_layers.ConvolutionLayer(\n",
      "        name='conv1', num_kernels=20, ksize=5, stride=1, mode='valid',\n",
      "        filler=fillers.XavierFiller())\n",
      "    pool1 = core_layers.PoolingLayer(name='pool1', psize=2, mode='max')\n",
      "    conv2 = core_layers.ConvolutionLayer(\n",
      "        name='conv2', num_kernels=50, ksize=5, stride=1, mode='valid',\n",
      "        filler=fillers.XavierFiller())\n",
      "    pool2 = core_layers.PoolingLayer(name='pool2', psize=2, mode='max')\n",
      "    # Flatten, then two inner-product layers with a ReLU in between.\n",
      "    flatten = core_layers.FlattenLayer(name='flatten')\n",
      "    ip1 = core_layers.InnerProductLayer(\n",
      "        name='ip1', num_output=500,\n",
      "        filler=fillers.XavierFiller(),\n",
      "        bias_filler=fillers.ConstantFiller(value=0.1))\n",
      "    relu1 = core_layers.ReLULayer(name='relu1')\n",
      "    ip2 = core_layers.InnerProductLayer(\n",
      "        name='ip2', num_output=10,\n",
      "        filler=fillers.XavierFiller())\n",
      "    return [conv1, pool1, conv2, pool2, flatten, ip1, relu1, ip2]"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": true,
     "input": [
      "# Helpers and the main function to carry out lenet training.\n",
      "def _build_training_net():\n",
      "    \"\"\"Assemble the training net: MNIST data, minibatch, LeNet, then loss.\n",
      "\n",
      "    Returns:\n",
      "        The base.Net instance after finish() has been called on it.\n",
      "    \"\"\"\n",
      "    decaf_net = base.Net()\n",
      "    # Data layer, constructed with is_training=True.\n",
      "    dataset = mnist.MNISTDataLayer(\n",
      "        name='mnist', rootfolder=ROOT, is_training=True, dtype=np.float64)\n",
      "    decaf_net.add_layer(dataset,\n",
      "                        provides=['image-all', 'label-all'])\n",
      "    # Minibatch layer for stochastic optimization.\n",
      "    minibatch_layer = core_layers.BasicMinibatchLayer(\n",
      "        name='batch', minibatch=MINIBATCH)\n",
      "    decaf_net.add_layer(minibatch_layer,\n",
      "                        needs=['image-all', 'label-all'],\n",
      "                        provides=['image', 'label'])\n",
      "    # LeNet layers: need the 'image' blob, provide the 'prediction' blob.\n",
      "    decaf_net.add_layers(lenet(), needs='image', provides='prediction')\n",
      "    # Loss layer: needs the prediction and the label.\n",
      "    loss_layer = core_layers.MultinomialLogisticLossLayer(\n",
      "        name='loss')\n",
      "    decaf_net.add_layer(loss_layer,\n",
      "                        needs=['prediction', 'label'])\n",
      "    decaf_net.finish()\n",
      "    return decaf_net\n",
      "\n",
      "\n",
      "def _evaluate_on_test_set(model_file):\n",
      "    \"\"\"Load a saved net and measure its accuracy on the is_training=False data.\n",
      "\n",
      "    Args:\n",
      "        model_file: path of a net previously written with Net.save().\n",
      "\n",
      "    Returns:\n",
      "        The fraction of test labels that equal the arg-max of the prediction.\n",
      "    \"\"\"\n",
      "    prediction_net = base.Net.load(model_file)\n",
      "    # Obtain the test data by running the data layer forward into two blobs.\n",
      "    dataset_test = mnist.MNISTDataLayer(\n",
      "        name='mnist', rootfolder=ROOT, is_training=False, dtype=np.float64)\n",
      "    test_image = base.Blob()\n",
      "    test_label = base.Blob()\n",
      "    dataset_test.forward([], [test_image, test_label])\n",
      "    # Run the net.\n",
      "    pred = prediction_net.predict(image=test_image)['prediction']\n",
      "    labels = test_label.data()\n",
      "    # float() keeps the division from truncating under Python 2.\n",
      "    return (pred.argmax(1) == labels).sum() / float(labels.size)\n",
      "\n",
      "\n",
      "def main(model_file='mnist_lenet.decafnet'):\n",
      "    \"\"\"Train LeNet on MNIST, save it, reload it and print the test accuracy.\n",
      "\n",
      "    Args:\n",
      "        model_file: path the trained net is saved to and then reloaded from.\n",
      "            The default is the file name this notebook has always used.\n",
      "\n",
      "    Returns:\n",
      "        The net that was trained (not the reloaded copy used for prediction).\n",
      "    \"\"\"\n",
      "    logging.getLogger().setLevel(logging.INFO)\n",
      "    ####################################\n",
      "    # First, let's create the decaf net.\n",
      "    ####################################\n",
      "    logging.info('Loading data and creating the network...')\n",
      "    decaf_net = _build_training_net()\n",
      "\n",
      "    logging.info('Optimizing')\n",
      "    solver = core_solvers.SGDSolver(base_lr=0.01, lr_policy='inv',\n",
      "                                    gamma=0.001, power=0.75, momentum=0.9,\n",
      "                                    max_iter=MAX_ITER)\n",
      "    solver.solve(decaf_net)\n",
      "    # Optional: visualize.draw_net_to_file(decaf_net, 'mnist_lenet.png')\n",
      "    decaf_net.save(model_file)\n",
      "\n",
      "    ##############################################\n",
      "    # Now, let's load the net and run predictions.\n",
      "    ##############################################\n",
      "    accuracy = _evaluate_on_test_set(model_file)\n",
      "    print 'Testing accuracy:', accuracy\n",
      "    print 'Done.'\n",
      "    return decaf_net"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": true,
     "input": [
      "# Run Forrest, run!\n",
      "# Calls main() and keeps the net it returns in decaf_net, which the\n",
      "# cells below read from.\n",
      "decaf_net = main()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Sneak peek\n",
      "# Print the keys of decaf_net.blobs; the cells below index blobs by these keys.\n",
      "print 'Blobs: ', decaf_net.blobs.keys()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Now, let's look at a training batch.\n",
      "# Take at most the first 32 entries of the 'image' blob. reshape(-1, 28, 28)\n",
      "# lets numpy infer the image count, so a blob holding fewer than 32 images\n",
      "# is displayed instead of raising a reshape error.\n",
      "batch_images = decaf_net.blobs['image'].data()[:32]\n",
      "pyplot.figure()\n",
      "_ = visualize.show_multiple(batch_images.reshape(-1, 28, 28))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# First layer of magic!\n",
      "# One figure per entry for entries 0, 1 and 2 of the conv1 output,\n",
      "# each drawn with show_channels and the RdBu_r colour map.\n",
      "conv1_out = decaf_net.blobs['_decaf_conv1_out'].data()\n",
      "for idx in range(3):\n",
      "    pyplot.figure()\n",
      "    _ = visualize.show_channels(conv1_out[idx])\n",
      "    pyplot.set_cmap(pyplot.cm.RdBu_r)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Second layer of magic!\n",
      "# One figure per entry for entries 0, 1 and 2 of the conv2 output,\n",
      "# each drawn with show_channels and the RdBu_r colour map.\n",
      "conv2_out = decaf_net.blobs['_decaf_conv2_out'].data()\n",
      "for idx in range(3):\n",
      "    pyplot.figure()\n",
      "    _ = visualize.show_channels(conv2_out[idx])\n",
      "    pyplot.set_cmap(pyplot.cm.RdBu_r)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# First flat output!\n",
      "# Draw the first 32 rows of the relu1 output as an image, with a colour bar.\n",
      "relu1_rows = decaf_net.blobs['_decaf_relu1_out'].data()[:32]\n",
      "pyplot.figure()\n",
      "_ = pyplot.imshow(relu1_rows, cmap=pyplot.cm.RdBu_r,\n",
      "                  interpolation='nearest', aspect='auto')\n",
      "_ = pyplot.colorbar()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Second flat output!\n",
      "# Draw the first 32 rows of the prediction blob as an image, with a colour bar.\n",
      "prediction_rows = decaf_net.blobs['prediction'].data()[:32]\n",
      "pyplot.figure()\n",
      "_ = pyplot.imshow(prediction_rows, cmap=pyplot.cm.RdBu_r,\n",
      "                  interpolation='nearest')\n",
      "_ = pyplot.colorbar()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Let's see if we have predicted things well: print the arg-max of the\n",
      "# prediction blob next to the label blob for the first 32 entries.\n",
      "# Usually, on a training batch, we will have perfect predictions.\n",
      "predicted_digits = decaf_net.blobs['prediction'].data().argmax(1)\n",
      "true_digits = decaf_net.blobs['label'].data()\n",
      "print 'prediction:', predicted_digits[:32]\n",
      "print 'label:     ', true_digits[:32]"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# If you did not specify inline output, this will show all the plots\n",
      "# opened by the pyplot.figure() calls in the cells above.\n",
      "pyplot.show()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [],
     "language": "python",
     "metadata": {},
     "outputs": []
    }
   ],
   "metadata": {}
  }
 ]
}